import matplotlib.pyplot as plt
import torchvision.transforms
from torch.utils.data import Dataset
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader


class Model(nn.Module):
    """Fully connected classifier: flatten -> stack of Linear layers -> logits.

    Args:
        in_features: Number of input features after flattening all non-batch
            dimensions. Defaults to 20 (the original hard-coded value).
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, in_features=20, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            # Flatten first so any input whose non-batch dims multiply to
            # `in_features` is accepted (e.g. (B, 20) or (B, 1, 20)).
            nn.Flatten(),
            nn.Linear(in_features, 1024),
            # NOTE(review): there is no activation between this Linear and
            # the next one, so the two collapse into a single affine map.
            # Confirm whether an nn.ReLU() was intended here.
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        # Must be defined at class level (not nested in __init__) so that
        # nn.Module.__call__ dispatches to it.
        return self.net(x)


